/* eslint-disable @typescript-eslint/no-shadow */
import React from "react";
import { StyleSheet, View } from "react-native";
import { Canvas } from "react-native-wgpu";
import type { Mat4 } from "wgpu-matrix";
import { mat4, vec3 } from "wgpu-matrix";

import { useWebGPU } from "../components/useWebGPU";
import { GUI } from "../components/GUI";

import {
  fragmentPrecisionErrorPassWGSL,
  fragmentTextureQuadWGSL,
  fragmentWGSL,
  vertexDepthPrePassWGSL,
  vertexPrecisionErrorPassWGSL,
  vertexTextureQuadWGSL,
  vertexWGSL,
} from "./Shaders";

// Two planes close to each other for depth precision test
const geometryVertexSize = 4 * 8; // Byte size of one geometry vertex.
const geometryPositionOffset = 0;
const geometryColorOffset = 4 * 4; // Byte offset of geometry vertex color attribute.
const geometryDrawCount = 6 * 2;

const d = 0.0001; // half distance between two planes
const o = 0.5; // half x offset to shift planes so they are only partially overlaping

// prettier-ignore
export const geometryVertexArray = new Float32Array([
  // float4 position, float4 color
  -1 - o, -1, d, 1, 1, 0, 0, 1,
  1 - o, -1, d, 1, 1, 0, 0, 1,
  -1 - o, 1, d, 1, 1, 0, 0, 1,
  1 - o, -1, d, 1, 1, 0, 0, 1,
  1 - o, 1, d, 1, 1, 0, 0, 1,
  -1 - o, 1, d, 1, 1, 0, 0, 1,

  -1 + o, -1, -d, 1, 0, 1, 0, 1,
  1 + o, -1, -d, 1, 0, 1, 0, 1,
  -1 + o, 1, -d, 1, 0, 1, 0, 1,
  1 + o, -1, -d, 1, 0, 1, 0, 1,
  1 + o, 1, -d, 1, 0, 1, 0, 1,
  -1 + o, 1, -d, 1, 0, 1, 0, 1,
]);

const xCount = 1;
const yCount = 5;
const numInstances = xCount * yCount;
const matrixFloatCount = 16; // 4x4 matrix
const matrixStride = 4 * matrixFloatCount; // 64;

const depthRangeRemapMatrix = mat4.identity();
depthRangeRemapMatrix[10] = -1;
depthRangeRemapMatrix[14] = 1;

const DepthBufferMode = {
  Default: 0,
  Reversed: 1,
};

const depthBufferModes = [DepthBufferMode.Default, DepthBufferMode.Reversed];
const depthCompareFuncs = {
  [DepthBufferMode.Default]: "less" as GPUCompareFunction,
  [DepthBufferMode.Reversed]: "greater" as GPUCompareFunction,
};
const depthClearValues = {
  [DepthBufferMode.Default]: 1.0,
  [DepthBufferMode.Reversed]: 0.0,
};

export const ReversedZ = () => {
  const ref = useWebGPU(({ context, device, presentationFormat, canvas }) => {
    context.configure({
      device,
      format: presentationFormat,
      alphaMode: "premultiplied",
    });

    const verticesBuffer = device.createBuffer({
      size: geometryVertexArray.byteLength,
      usage: GPUBufferUsage.VERTEX,
      mappedAtCreation: true,
    });
    new Float32Array(verticesBuffer.getMappedRange()).set(geometryVertexArray);
    verticesBuffer.unmap();

    const depthBufferFormat = "depth32float";

    const depthTextureBindGroupLayout = device.createBindGroupLayout({
      entries: [
        {
          binding: 0,
          visibility: GPUShaderStage.FRAGMENT,
          texture: {
            sampleType: "depth",
          },
        },
      ],
    });

    // Model, view, projection matrices
    const uniformBindGroupLayout = device.createBindGroupLayout({
      entries: [
        {
          binding: 0,
          visibility: GPUShaderStage.VERTEX,
          buffer: {
            type: "uniform",
          },
        },
        {
          binding: 1,
          visibility: GPUShaderStage.VERTEX,
          buffer: {
            type: "uniform",
          },
        },
      ],
    });

    const depthPrePassRenderPipelineLayout = device.createPipelineLayout({
      bindGroupLayouts: [uniformBindGroupLayout],
    });

    // depthPrePass is used to render scene to the depth texture
    // this is not needed if you just want to use reversed z to render a scene
    const depthPrePassRenderPipelineDescriptorBase = {
      layout: depthPrePassRenderPipelineLayout,
      vertex: {
        module: device.createShaderModule({
          code: vertexDepthPrePassWGSL,
        }),
        buffers: [
          {
            arrayStride: geometryVertexSize,
            attributes: [
              {
                // position
                shaderLocation: 0,
                offset: geometryPositionOffset,
                format: "float32x4",
              },
            ],
          },
        ],
      },
      primitive: {
        topology: "triangle-list",
        cullMode: "back",
      },
      depthStencil: {
        depthWriteEnabled: true,
        depthCompare: "less",
        format: depthBufferFormat,
      },
    } as GPURenderPipelineDescriptor;

    // we need the depthCompare to fit the depth buffer mode we are using.
    // this is the same for other passes
    const depthPrePassPipelines: GPURenderPipeline[] = [];
    // eslint-disable-next-line @typescript-eslint/ban-ts-comment
    // @ts-expect-error
    depthPrePassRenderPipelineDescriptorBase.depthStencil.depthCompare =
      depthCompareFuncs[DepthBufferMode.Default];
    depthPrePassPipelines[DepthBufferMode.Default] =
      device.createRenderPipeline(depthPrePassRenderPipelineDescriptorBase);
    // eslint-disable-next-line @typescript-eslint/ban-ts-comment
    // @ts-expect-error
    depthPrePassRenderPipelineDescriptorBase.depthStencil.depthCompare =
      depthCompareFuncs[DepthBufferMode.Reversed];
    depthPrePassPipelines[DepthBufferMode.Reversed] =
      device.createRenderPipeline(depthPrePassRenderPipelineDescriptorBase);

    // precisionPass is to draw precision error as color of depth value stored in depth buffer
    // compared to that directly calcualated in the shader
    const precisionPassRenderPipelineLayout = device.createPipelineLayout({
      bindGroupLayouts: [uniformBindGroupLayout, depthTextureBindGroupLayout],
    });
    const precisionPassRenderPipelineDescriptorBase = {
      layout: precisionPassRenderPipelineLayout,
      vertex: {
        module: device.createShaderModule({
          code: vertexPrecisionErrorPassWGSL,
        }),
        buffers: [
          {
            arrayStride: geometryVertexSize,
            attributes: [
              {
                // position
                shaderLocation: 0,
                offset: geometryPositionOffset,
                format: "float32x4",
              },
            ],
          },
        ],
      },
      fragment: {
        module: device.createShaderModule({
          code: fragmentPrecisionErrorPassWGSL,
        }),
        targets: [
          {
            format: presentationFormat,
          },
        ],
      },
      primitive: {
        topology: "triangle-list",
        cullMode: "back",
      },
      depthStencil: {
        depthWriteEnabled: true,
        depthCompare: "less",
        format: depthBufferFormat,
      },
    } as GPURenderPipelineDescriptor;
    const precisionPassPipelines: GPURenderPipeline[] = [];
    // eslint-disable-next-line @typescript-eslint/ban-ts-comment
    // @ts-expect-error
    precisionPassRenderPipelineDescriptorBase.depthStencil.depthCompare =
      depthCompareFuncs[DepthBufferMode.Default];
    precisionPassPipelines[DepthBufferMode.Default] =
      device.createRenderPipeline(precisionPassRenderPipelineDescriptorBase);
    // eslint-disable-next-line @typescript-eslint/ban-ts-comment
    // @ts-expect-error
    precisionPassRenderPipelineDescriptorBase.depthStencil.depthCompare =
      depthCompareFuncs[DepthBufferMode.Reversed];
    // prettier-ignore
    precisionPassPipelines[DepthBufferMode.Reversed] = device.createRenderPipeline(
  precisionPassRenderPipelineDescriptorBase
);

    // colorPass is the regular render pass to render the scene
    const colorPassRenderPiplineLayout = device.createPipelineLayout({
      bindGroupLayouts: [uniformBindGroupLayout],
    });
    const colorPassRenderPipelineDescriptorBase: GPURenderPipelineDescriptor = {
      layout: colorPassRenderPiplineLayout,
      vertex: {
        module: device.createShaderModule({
          code: vertexWGSL,
        }),
        buffers: [
          {
            arrayStride: geometryVertexSize,
            attributes: [
              {
                // position
                shaderLocation: 0,
                offset: geometryPositionOffset,
                format: "float32x4",
              },
              {
                // color
                shaderLocation: 1,
                offset: geometryColorOffset,
                format: "float32x4",
              },
            ],
          },
        ],
      },
      fragment: {
        module: device.createShaderModule({
          code: fragmentWGSL,
        }),
        targets: [
          {
            format: presentationFormat,
          },
        ],
      },
      primitive: {
        topology: "triangle-list",
        cullMode: "back",
      },
      depthStencil: {
        depthWriteEnabled: true,
        depthCompare: "less",
        format: depthBufferFormat,
      },
    };
    const colorPassPipelines: GPURenderPipeline[] = [];
    // eslint-disable-next-line @typescript-eslint/ban-ts-comment
    // @ts-expect-error
    colorPassRenderPipelineDescriptorBase.depthStencil.depthCompare =
      depthCompareFuncs[DepthBufferMode.Default];
    colorPassPipelines[DepthBufferMode.Default] = device.createRenderPipeline(
      colorPassRenderPipelineDescriptorBase,
    );
    // eslint-disable-next-line @typescript-eslint/ban-ts-comment
    // @ts-expect-error
    colorPassRenderPipelineDescriptorBase.depthStencil.depthCompare =
      depthCompareFuncs[DepthBufferMode.Reversed];
    colorPassPipelines[DepthBufferMode.Reversed] = device.createRenderPipeline(
      colorPassRenderPipelineDescriptorBase,
    );

    // textureQuadPass is draw a full screen quad of depth texture
    // to see the difference of depth value using reversed z compared to default depth buffer usage
    // 0.0 will be the furthest and 1.0 will be the closest
    const textureQuadPassPiplineLayout = device.createPipelineLayout({
      bindGroupLayouts: [depthTextureBindGroupLayout],
    });
    const textureQuadPassPipline = device.createRenderPipeline({
      layout: textureQuadPassPiplineLayout,
      vertex: {
        module: device.createShaderModule({
          code: vertexTextureQuadWGSL,
        }),
      },
      fragment: {
        module: device.createShaderModule({
          code: fragmentTextureQuadWGSL,
        }),
        targets: [
          {
            format: presentationFormat,
          },
        ],
      },
      primitive: {
        topology: "triangle-list",
      },
    });

    const depthTexture = device.createTexture({
      size: [canvas.width, canvas.height],
      format: depthBufferFormat,
      usage:
        GPUTextureUsage.RENDER_ATTACHMENT | GPUTextureUsage.TEXTURE_BINDING,
    });
    const depthTextureView = depthTexture.createView();

    const defaultDepthTexture = device.createTexture({
      size: [canvas.width, canvas.height],
      format: depthBufferFormat,
      usage: GPUTextureUsage.RENDER_ATTACHMENT,
    });
    const defaultDepthTextureView = defaultDepthTexture.createView();

    const depthPrePassDescriptor: GPURenderPassDescriptor = {
      colorAttachments: [],
      depthStencilAttachment: {
        view: depthTextureView,

        depthClearValue: 1.0,
        depthLoadOp: "clear",
        depthStoreOp: "store",
      },
    };

    // drawPassDescriptor and drawPassLoadDescriptor are used for drawing
    // the scene twice using different depth buffer mode on splitted viewport
    // of the same canvas
    // see the difference of the loadOp of the colorAttachments
    const drawPassDescriptor: GPURenderPassDescriptor = {
      // eslint-disable-next-line @typescript-eslint/ban-ts-comment
      // @ts-expect-error
      colorAttachments: [
        {
          // view is acquired and set in render loop.
          view: undefined,

          clearValue: [0.0, 0.0, 0.5, 1.0],
          loadOp: "clear",
          storeOp: "store",
        },
      ],
      depthStencilAttachment: {
        view: defaultDepthTextureView,

        depthClearValue: 1.0,
        depthLoadOp: "clear",
        depthStoreOp: "store",
      },
    };
    const drawPassLoadDescriptor: GPURenderPassDescriptor = {
      // eslint-disable-next-line @typescript-eslint/ban-ts-comment
      // @ts-expect-error
      colorAttachments: [
        {
          // attachment is acquired and set in render loop.
          view: undefined,

          loadOp: "load",
          storeOp: "store",
        },
      ],
      depthStencilAttachment: {
        view: defaultDepthTextureView,

        depthClearValue: 1.0,
        depthLoadOp: "clear",
        depthStoreOp: "store",
      },
    };
    const drawPassDescriptors = [drawPassDescriptor, drawPassLoadDescriptor];

    const textureQuadPassDescriptor: GPURenderPassDescriptor = {
      // eslint-disable-next-line @typescript-eslint/ban-ts-comment
      // @ts-expect-error
      colorAttachments: [
        {
          // view is acquired and set in render loop.
          view: undefined,

          clearValue: [0.0, 0.0, 0.5, 1.0],
          loadOp: "clear",
          storeOp: "store",
        },
      ],
    };
    const textureQuadPassLoadDescriptor: GPURenderPassDescriptor = {
      // eslint-disable-next-line @typescript-eslint/ban-ts-comment
      // @ts-expect-error
      colorAttachments: [
        {
          // view is acquired and set in render loop.
          view: undefined,

          loadOp: "load",
          storeOp: "store",
        },
      ],
    };
    const textureQuadPassDescriptors = [
      textureQuadPassDescriptor,
      textureQuadPassLoadDescriptor,
    ];

    const depthTextureBindGroup = device.createBindGroup({
      layout: depthTextureBindGroupLayout,
      entries: [
        {
          binding: 0,
          resource: depthTextureView,
        },
      ],
    });

    const uniformBufferSize = numInstances * matrixStride;

    const uniformBuffer = device.createBuffer({
      size: uniformBufferSize,
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
    });
    const cameraMatrixBuffer = device.createBuffer({
      size: 4 * 16, // 4x4 matrix
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
    });
    const cameraMatrixReversedDepthBuffer = device.createBuffer({
      size: 4 * 16, // 4x4 matrix
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
    });

    const uniformBindGroups = [
      device.createBindGroup({
        layout: uniformBindGroupLayout,
        entries: [
          {
            binding: 0,
            resource: {
              buffer: uniformBuffer,
            },
          },
          {
            binding: 1,
            resource: {
              buffer: cameraMatrixBuffer,
            },
          },
        ],
      }),
      device.createBindGroup({
        layout: uniformBindGroupLayout,
        entries: [
          {
            binding: 0,
            resource: {
              buffer: uniformBuffer,
            },
          },
          {
            binding: 1,
            resource: {
              buffer: cameraMatrixReversedDepthBuffer,
            },
          },
        ],
      }),
    ];

    const modelMatrices = new Array<Mat4>(numInstances);
    const mvpMatricesData = new Float32Array(matrixFloatCount * numInstances);

    let m = 0;
    for (let x = 0; x < xCount; x++) {
      for (let y = 0; y < yCount; y++) {
        const z = -800 * m;
        const s = 1 + 50 * m;

        modelMatrices[m] = mat4.translation(
          vec3.fromValues(
            x - xCount / 2 + 0.5,
            (4.0 - 0.2 * z) * (y - yCount / 2 + 1.0),
            z,
          ),
        );
        mat4.scale(
          modelMatrices[m],
          vec3.fromValues(s, s, s),
          modelMatrices[m],
        );

        m++;
      }
    }

    const viewMatrix = mat4.translation(vec3.fromValues(0, 0, -12));

    const aspect = (0.5 * canvas.width) / canvas.height;
    // wgpu-matrix perspective doesn't handle zFar === Infinity now.
    // https://github.com/greggman/wgpu-matrix/issues/9
    const projectionMatrix = mat4.perspective(
      (2 * Math.PI) / 5,
      aspect,
      5,
      9999,
    );

    const viewProjectionMatrix = mat4.multiply(projectionMatrix, viewMatrix);
    // to use 1/z we just multiple depthRangeRemapMatrix to our default camera view projection matrix
    const reversedRangeViewProjectionMatrix = mat4.multiply(
      depthRangeRemapMatrix,
      viewProjectionMatrix,
    );

    device.queue.writeBuffer(cameraMatrixBuffer, 0, viewProjectionMatrix);
    device.queue.writeBuffer(
      cameraMatrixReversedDepthBuffer,
      0,
      reversedRangeViewProjectionMatrix,
    );

    const tmpMat4 = mat4.create();
    function updateTransformationMatrix() {
      const now = Date.now() / 1000;

      for (let i = 0, m = 0; i < numInstances; i++, m += matrixFloatCount) {
        mat4.rotate(
          modelMatrices[i],
          vec3.fromValues(Math.sin(now), Math.cos(now), 0),
          (Math.PI / 180) * 30,
          tmpMat4,
        );
        mvpMatricesData.set(tmpMat4, m);
      }
    }

    const settings = {
      mode: "color",
    };
    const gui = new GUI();
    gui.add(settings, "mode", ["color", "precision-error", "depth-texture"]);

    function frame() {
      updateTransformationMatrix();
      device.queue.writeBuffer(
        uniformBuffer,
        0,
        mvpMatricesData.buffer,
        mvpMatricesData.byteOffset,
        mvpMatricesData.byteLength,
      );

      const attachment = context.getCurrentTexture().createView();
      const commandEncoder = device.createCommandEncoder();
      if (settings.mode === "color") {
        for (const m of depthBufferModes) {
          // eslint-disable-next-line @typescript-eslint/ban-ts-comment
          // @ts-expect-error
          drawPassDescriptors[m].colorAttachments[0].view = attachment;
          // eslint-disable-next-line @typescript-eslint/ban-ts-comment
          // @ts-expect-error
          drawPassDescriptors[m].depthStencilAttachment.depthClearValue =
            depthClearValues[m];
          const colorPass = commandEncoder.beginRenderPass(
            drawPassDescriptors[m],
          );
          colorPass.setPipeline(colorPassPipelines[m]);
          colorPass.setBindGroup(0, uniformBindGroups[m]);
          colorPass.setVertexBuffer(0, verticesBuffer);
          colorPass.setViewport(
            (canvas.width * m) / 2,
            0,
            canvas.width / 2,
            canvas.height,
            0,
            1,
          );
          colorPass.draw(geometryDrawCount, numInstances, 0, 0);
          colorPass.end();
        }
      } else if (settings.mode === "precision-error") {
        for (const m of depthBufferModes) {
          {
            // eslint-disable-next-line @typescript-eslint/ban-ts-comment
            // @ts-expect-error
            depthPrePassDescriptor.depthStencilAttachment.depthClearValue =
              depthClearValues[m];
            const depthPrePass = commandEncoder.beginRenderPass(
              depthPrePassDescriptor,
            );
            depthPrePass.setPipeline(depthPrePassPipelines[m]);
            depthPrePass.setBindGroup(0, uniformBindGroups[m]);
            depthPrePass.setVertexBuffer(0, verticesBuffer);
            depthPrePass.setViewport(
              (canvas.width * m) / 2,
              0,
              canvas.width / 2,
              canvas.height,
              0,
              1,
            );
            depthPrePass.draw(geometryDrawCount, numInstances, 0, 0);
            depthPrePass.end();
          }
          {
            // eslint-disable-next-line @typescript-eslint/ban-ts-comment
            // @ts-expect-error
            drawPassDescriptors[m].colorAttachments[0].view = attachment;
            // eslint-disable-next-line @typescript-eslint/ban-ts-comment
            // @ts-expect-error
            drawPassDescriptors[m].depthStencilAttachment.depthClearValue =
              depthClearValues[m];
            const precisionErrorPass = commandEncoder.beginRenderPass(
              drawPassDescriptors[m],
            );
            precisionErrorPass.setPipeline(precisionPassPipelines[m]);
            precisionErrorPass.setBindGroup(0, uniformBindGroups[m]);
            precisionErrorPass.setBindGroup(1, depthTextureBindGroup);
            precisionErrorPass.setVertexBuffer(0, verticesBuffer);
            precisionErrorPass.setViewport(
              (canvas.width * m) / 2,
              0,
              canvas.width / 2,
              canvas.height,
              0,
              1,
            );
            precisionErrorPass.draw(geometryDrawCount, numInstances, 0, 0);
            precisionErrorPass.end();
          }
        }
      } else {
        // depth texture quad
        for (const m of depthBufferModes) {
          {
            // eslint-disable-next-line @typescript-eslint/ban-ts-comment
            // @ts-expect-error
            depthPrePassDescriptor.depthStencilAttachment.depthClearValue =
              depthClearValues[m];
            const depthPrePass = commandEncoder.beginRenderPass(
              depthPrePassDescriptor,
            );
            depthPrePass.setPipeline(depthPrePassPipelines[m]);
            depthPrePass.setBindGroup(0, uniformBindGroups[m]);
            depthPrePass.setVertexBuffer(0, verticesBuffer);
            depthPrePass.setViewport(
              (canvas.width * m) / 2,
              0,
              canvas.width / 2,
              canvas.height,
              0,
              1,
            );
            depthPrePass.draw(geometryDrawCount, numInstances, 0, 0);
            depthPrePass.end();
          }
          {
            // eslint-disable-next-line @typescript-eslint/ban-ts-comment
            // @ts-expect-error
            textureQuadPassDescriptors[m].colorAttachments[0].view = attachment;
            const depthTextureQuadPass = commandEncoder.beginRenderPass(
              textureQuadPassDescriptors[m],
            );
            depthTextureQuadPass.setPipeline(textureQuadPassPipline);
            depthTextureQuadPass.setBindGroup(0, depthTextureBindGroup);
            depthTextureQuadPass.setViewport(
              (canvas.width * m) / 2,
              0,
              canvas.width / 2,
              canvas.height,
              0,
              1,
            );
            depthTextureQuadPass.draw(6);
            depthTextureQuadPass.end();
          }
        }
      }
      device.queue.submit([commandEncoder.finish()]);
    }
    return frame;
  });

  return (
    <View style={style.container}>
      <Canvas ref={ref} style={style.webgpu} />
    </View>
  );
};

const style = StyleSheet.create({
  container: {
    flex: 1,
  },
  webgpu: {
    flex: 1,
  },
});
